import torch
import os
import numpy as np
import util.landscape.util as util
from util.trainer.model import load_model
import h5py
import torch.distributed as dist

def cal_cos(vec1, vec2):
    """Calculate the cosine similarity between two torch tensors or two numpy ndarrays.

    Args:
        vec1, vec2: two 1-D ``torch.Tensor`` objects, or two ``np.ndarray`` objects.

    Returns:
        ``<vec1, vec2> / (||vec1|| * ||vec2||)``. This is a Python float for tensors
        (via ``.item()``) and a NumPy scalar for 1-D ndarrays.

    Raises:
        TypeError: if the arguments are not both tensors or both ndarrays.
    """
    if isinstance(vec1, torch.Tensor) and isinstance(vec2, torch.Tensor):
        # .item() wraps the whole quotient so the caller gets a plain Python float.
        return (torch.dot(vec1, vec2) / (vec1.norm() * vec2.norm())).item()
    if isinstance(vec1, np.ndarray) and isinstance(vec2, np.ndarray):
        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
    raise TypeError(
        "cal_cos expects two torch.Tensor or two np.ndarray arguments, got "
        f"{type(vec1).__name__} and {type(vec2).__name__}"
    )

def project_1D_coordinate(w, d):
    """ Project vector w onto vector d and return the signed length of the projection
        measured in units of d, i.e. ``<w, d> / <d, d>``.

        Both vectors are cast to float32 before any arithmetic, so low-precision
        inputs (e.g. float16) do not overflow when ``d``'s squared norm is formed.

        Args:
            w: vectorized weights
            d: vectorized direction (same length as w)

        Returns:
            the projection scalar as a Python float (nan if d is all zeros)

        Raises:
            ValueError: if w and d have different lengths.
    """
    if len(w) != len(d):
        raise ValueError('dimension does not match for w and d')
    d32 = d.to(torch.float32)
    # <d, d> equals ||d||^2; computing it in float32 keeps numerator and denominator consistent.
    scale = torch.dot(w.to(torch.float32), d32) / torch.dot(d32, d32)
    return scale.item()

def project_2D_coordinate(d, dx, dy, basis):
    """ Return the coordinates (x, y) of vector d in the plane spanned by dx and dy.

        Args:
            d: vectorized weights
            dx: first vectorized direction (same length as d)
            dy: second vectorized direction (same length as d)
            basis: basis type of the plane. For "orthonormal", "orthogonal" and
                "orthoscale" the directions are treated as mutually orthogonal and d
                is projected onto each one independently. For any other value the
                least-squares problem ``[dx | dy] @ [x, y]^T ~= d`` is solved.
        Returns:
            x, y: the projection coordinates as Python floats
    """
    if basis in ("orthonormal", "orthogonal", "orthoscale"):
        return project_1D_coordinate(d, dx), project_1D_coordinate(d, dy)
    # Solve the least-squares problem A @ [x, y]^T ~= d, where A = [dx | dy] has shape (n, 2).
    A = torch.stack([dx, dy], dim=1).to(torch.float32)
    b = d.to(torch.float32).unsqueeze(1)  # d as a column vector of shape (n, 1)
    x_y = torch.linalg.lstsq(A, b).solution  # least-squares solution of shape (2, 1)
    return x_y[0].item(), x_y[1].item()


# def project_trajectory(dx, dy, args):
#     xcoord, ycoord = [], []
#     for root, _, files in os.walk(args.path_to_trajectory):
#         files = sorted(files, key=util.extract_number)
#         init_model=load_model(os.path.join(root, files[0]), args)
#         files.pop(0)
#         for file in files:
#             x, y = project_2D_coordinate(util.get_diff_weights(init_model, load_model(os.path.join(root, file), args)), dx, dy, args.basis)
#             xcoord.append(x)
#             ycoord.append(y)
#     return xcoord, ycoord

def _is_main_process(rank):
    """Return True for the process that logs and writes files: rank 0, or None when not distributed."""
    return rank is None or rank == 0


def _barrier_if_distributed():
    """Synchronize all ranks, but only when a default process group has been initialized."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()


def _projected_trajectories_path(args):
    """Return the .h5 file used to cache the projected trajectories described by ``args``."""
    if args.animation:
        return os.path.join(args.path_to_proj_trajs_file, args.id, "frame_" + str(args.current_epoch) + ".h5")
    return os.path.join(args.path_to_proj_trajs_file, args.id + ".h5")


def _project_trajectory(traj_dir, origin, directions, args):
    """Project every checkpoint found under ``traj_dir`` onto the (directions[0], directions[1]) plane.

    Directories are visited in ``os.walk`` order; files within each directory are
    sorted with ``util.extract_number``. Returns ``np.array([xcoord, ycoord])``,
    shape (2, number of checkpoints).
    """
    xcoord, ycoord = [], []
    for root, _, files in os.walk(traj_dir):
        for name in sorted(files, key=util.extract_number):
            model = load_model(os.path.join(root, name), args)
            diff = util.get_diff_weights(origin, model.parameters())
            x, y = project_2D_coordinate(diff, directions[0], directions[1], args.basis)
            xcoord.append(x)
            ycoord.append(y)
    return np.array([xcoord, ycoord])


def get_projected_trajectories(args, origin, directions, rank, checkpoint_root="../checkpoints"):
    """Return the plane coordinates of every trajectory's checkpoints, using an HDF5 cache.

    The cache file is ``<path_to_proj_trajs_file>/<id>.h5``, or
    ``<path_to_proj_trajs_file>/<id>/frame_<current_epoch>.h5`` when ``args.animation``
    is set. If it exists, it is read. Otherwise, each trajectory in ``args.trajectories``
    is projected from the checkpoints under ``{checkpoint_root}/{traj}``. The main
    process (rank 0 or ``None``) then writes the cache.

    When a torch.distributed process group is initialized, all ranks synchronize:
    - after checking for the cache, so no rank starts writing before every rank has
      decided which branch to take;
    - after the cache is written, so no rank can later observe a partially written file.

    Args:
        args: run configuration. Reads ``animation``, ``path_to_proj_trajs_file``,
            ``id``, ``current_epoch`` (animation only), ``trajectories`` and ``basis``.
            It is also forwarded to ``load_model``.
        origin: reference weights passed to ``util.get_diff_weights``.
        directions: sequence whose first two entries span the projection plane.
        rank: this process's rank, or ``None`` when not running distributed.
        checkpoint_root: directory holding one checkpoint sub-directory per trajectory.

    Returns:
        dict mapping each trajectory name (as ``str``) to an array ``[xcoord, ycoord]``.
    """
    proj_trajs_file = _projected_trajectories_path(args)
    is_main = _is_main_process(rank)
    if is_main:
        print(f"The trajectories are projected on the plane derived from {args.id} origin and directions.")

    cache_exists = os.path.exists(proj_trajs_file)
    # Every rank must have checked for the cache before rank 0 may create it,
    # otherwise ranks could take different branches and mismatch collectives.
    _barrier_if_distributed()

    if cache_exists:
        with h5py.File(proj_trajs_file, "r") as f:
            return {key: f[key][()] for key in f.keys()}

    trajectories_coords = {
        str(traj): _project_trajectory(f"{checkpoint_root}/{traj}", origin, directions, args)
        for traj in args.trajectories
    }

    if is_main:
        out_dir = os.path.dirname(proj_trajs_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with h5py.File(proj_trajs_file, "w") as f:
            for key, coords in trajectories_coords.items():
                f[key] = coords
    # Keep other ranks here until the cache file is completely written.
    _barrier_if_distributed()
    return trajectories_coords

